# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from typing import Annotated, cast

from fastapi import Depends, HTTPException, Query, Response, status
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.connections import (
    ConnectionBody,
    ConnectionBulkBody,
    ConnectionCollectionResponse,
    ConnectionResponse,
    ConnectionTestResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.configuration import conf
from airflow.models import Connection
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.utils.db import create_default_connections as db_create_default_connections
from airflow.utils.strings import get_random_string

connections_router = AirflowRouter(tags=["Connection"], prefix="/connections")


@connections_router.delete(
    "/{connection_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def delete_connection(
    connection_id: str,
    session: SessionDep,
):
    """Delete a connection entry."""
    connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))

    if connection is None:
        raise HTTPException(
            status.HTTP_404_NOT_FOUND, f"The Connection with connection_id: `{connection_id}` was not found"
        )

    session.delete(connection)


@connections_router.get(
    "/{connection_id}",
    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_connection(
    connection_id: str,
    session: SessionDep,
) -> ConnectionResponse:
    """Get a connection entry."""
    connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))

    if connection is None:
        raise HTTPException(
            status.HTTP_404_NOT_FOUND, f"The Connection with connection_id: `{connection_id}` was not found"
        )

    return connection


@connections_router.get(
    "",
    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
)
def get_connections(
    limit: QueryLimit,
    offset: QueryOffset,
    order_by: Annotated[
        SortParam,
        Depends(
            SortParam(
                ["conn_id", "conn_type", "description", "host", "port", "id"],
                Connection,
                {"connection_id": "conn_id"},
            ).dynamic_depends()
        ),
    ],
    session: SessionDep,
) -> ConnectionCollectionResponse:
    """Get all connection entries."""
    connection_select, total_entries = paginated_select(
        statement=select(Connection),
        order_by=order_by,
        offset=offset,
        limit=limit,
        session=session,
    )

    connections = session.scalars(connection_select)

    return ConnectionCollectionResponse(
        connections=connections,
        total_entries=total_entries,
    )


@connections_router.post(
    "",
    status_code=status.HTTP_201_CREATED,
    responses=create_openapi_http_exception_doc(
        [status.HTTP_409_CONFLICT]
    ),  # handled by global exception handler
)
def post_connection(
    post_body: ConnectionBody,
    session: SessionDep,
) -> ConnectionResponse:
    """Create connection entry."""
    connection = Connection(**post_body.model_dump(by_alias=True))
    session.add(connection)
    return connection


@connections_router.put(
    "/bulk",
    responses={
        **create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]),
        status.HTTP_201_CREATED: {
            "description": "Created",
            "model": ConnectionCollectionResponse,
        },
        status.HTTP_200_OK: {
            "description": "Created with overwrite",
            "model": ConnectionCollectionResponse,
        },
    },
)
def put_connections(
    response: Response,
    post_body: ConnectionBulkBody,
    session: SessionDep,
) -> ConnectionCollectionResponse:
    """Create connection entry."""
    response.status_code = status.HTTP_201_CREATED if not post_body.overwrite else status.HTTP_200_OK
    connections: list[Connection]
    if not post_body.overwrite:
        connections = [Connection(**body.model_dump(by_alias=True)) for body in post_body.connections]
        session.add_all(connections)
    else:
        connection_ids = [conn.connection_id for conn in post_body.connections]
        existed_connections = session.execute(
            select(Connection).filter(Connection.conn_id.in_(connection_ids))
        ).scalars()
        existed_connections_dict = {conn.conn_id: conn for conn in existed_connections}
        connections = []
        # if conn_id exists, update the corresponding connection, else add a new connection
        for body in post_body.connections:
            if body.connection_id in existed_connections_dict:
                connection = existed_connections_dict[body.connection_id]
                for key, val in body.model_dump(by_alias=True).items():
                    setattr(connection, key, val)
                connections.append(connection)
            else:
                connections.append(Connection(**body.model_dump(by_alias=True)))
        session.add_all(connections)
    return ConnectionCollectionResponse(
        connections=cast(list[ConnectionResponse], connections),
        total_entries=len(connections),
    )


@connections_router.patch(
    "/{connection_id}",
    responses=create_openapi_http_exception_doc(
        [
            status.HTTP_400_BAD_REQUEST,
            status.HTTP_404_NOT_FOUND,
        ]
    ),
)
def patch_connection(
    connection_id: str,
    patch_body: ConnectionBody,
    session: SessionDep,
    update_mask: list[str] | None = Query(None),
) -> ConnectionResponse:
    """Update a connection entry."""
    if patch_body.connection_id != connection_id:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "The connection_id in the request body does not match the URL parameter",
        )

    non_update_fields = {"connection_id", "conn_id"}
    connection = session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))

    if connection is None:
        raise HTTPException(
            status.HTTP_404_NOT_FOUND, f"The Connection with connection_id: `{connection_id}` was not found"
        )

    fields_to_update = patch_body.model_fields_set

    if update_mask:
        fields_to_update = fields_to_update.intersection(update_mask)
        data = patch_body.model_dump(include=fields_to_update - non_update_fields, by_alias=True)
    else:
        try:
            ConnectionBody(**patch_body.model_dump())
        except ValidationError as e:
            raise RequestValidationError(errors=e.errors())
        data = patch_body.model_dump(exclude=non_update_fields, by_alias=True)

    for key, val in data.items():
        setattr(connection, key, val)

    return connection


@connections_router.post(
    "/test",
)
def test_connection(
    test_body: ConnectionBody,
) -> ConnectionTestResponse:
    """
    Test an API connection.

    This method first creates an in-memory transient conn_id & exports that to an env var,
    as some hook classes tries to find out the `conn` from their __init__ method & errors out if not found.
    It also deletes the conn id env variable after the test.
    """
    if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled":
        raise HTTPException(
            status.HTTP_403_FORBIDDEN,
            "Testing connections is disabled in Airflow configuration. "
            "Contact your deployment admin to enable it.",
        )

    transient_conn_id = get_random_string()
    conn_env_var = f"{CONN_ENV_PREFIX}{transient_conn_id.upper()}"
    try:
        data = test_body.model_dump(by_alias=True)
        data["conn_id"] = transient_conn_id
        conn = Connection(**data)
        os.environ[conn_env_var] = conn.get_uri()
        test_status, test_message = conn.test_connection()
        return ConnectionTestResponse.model_validate({"status": test_status, "message": test_message})
    finally:
        os.environ.pop(conn_env_var, None)


@connections_router.post(
    "/defaults",
    status_code=status.HTTP_204_NO_CONTENT,
)
def create_default_connections(
    session: SessionDep,
):
    """Create default connections."""
    db_create_default_connections(session)
